from memory_profiler import profile
import os
#import tempfile
#from subprocess import Popen, PIPE
import numpy as np
import keras

#import shutil
#import pandas as pd
#import warnings
#import datetime
from skimage.util import view_as_windows as vaw

os.environ['SIDEKIT'] = 'theano=false,libsvm=false'
#from sidekit.frontend.io import read_wav
from sidekit.frontend.features import mfcc

from pyannote.algorithms.utils.viterbi import viterbi_decoding
from .viterbi_utils import pred2logemission, diag_trans_exp, log_trans_exp
import librosa
import gc
import sys

import tensorflow as tf
# Capture the default TF graph at import time; get_predict() re-enters it so
# that model.predict() works when called from a different thread/context.
graph = tf.get_default_graph()

def _wav2feats(wavname):
    """Compute acoustic features for a wav file.

    Parameters
    ----------
    wavname : str
        Path to a .wav / .wave file.

    Returns
    -------
    (mspec, loge)
        Mel spectrogram and per-frame log energy, as produced by sidekit's
        mfcc(get_mspec=True).
    """
    ext = os.path.splitext(wavname)[-1]
    assert ext.lower() in ('.wav', '.wave')
    sampwidth = 2

    sig, _framerate = get_voice_ratedata(wavname)
    # NOTE(review): scale factor 2**(15-sampwidth) mirrors the original
    # sidekit read_wav convention; confirm it is still the right scaling for
    # librosa's float output.
    sig *= (2 ** (15 - sampwidth))
    # Only the mel spectrogram and log energy are used downstream; the
    # cepstral coefficients and labels are discarded.
    _, loge, _, mspec = mfcc(sig.astype(np.float32), get_mspec=True)
    gc.collect()  # deliberate: release the large intermediate arrays eagerly
    return mspec, loge

def get_voice_ratedata(wavname):
    """Load *wavname* resampled to 16 kHz mono.

    Returns
    -------
    (sig, read_framerate)
        Float signal array and its sample rate, as returned by librosa.load.
    """
    # Debug refcount prints removed (leftover memory-leak instrumentation).
    return librosa.load(wavname, sr=16000)

# import librosa
# import sys
# from memory_profiler import profile
# wavname="/usr/lzj/Pro/ASR_python/n1.wav"
# sig, read_framerate = librosa.load(wavname, sr=16000)

@profile
def _energy_activity(loge, ratio=0.03):
    """Energy-based voice activity detection, smoothed with Viterbi decoding.

    A frame is raw-active when its log energy exceeds the mean finite log
    energy shifted by log(ratio); the binary sequence is then regularized by
    Viterbi decoding with a strong stay-in-state transition cost.
    """
    finite_loge = loge[np.isfinite(loge)]
    threshold = np.mean(finite_loge) + np.log(ratio)
    is_active = loge > threshold
    emissions = pred2logemission(is_active)
    return viterbi_decoding(emissions, log_trans_exp(150, cost0=-5))

@profile
def _get_patches(mspec, w, step):
    """Slice the mel spectrogram into overlapping, per-patch-normalized windows.

    Parameters
    ----------
    mspec : 2D array (nframes, nmel) mel spectrogram.
    w : patch width in frames.
    step : hop between consecutive patches, in frames.

    Returns
    -------
    data : 3D array (npatches, w, nmel) of mean/std-normalized patches.
    finite : 1D boolean array, True for patches containing only finite values.
    """
    h = mspec.shape[1]
    # Sliding windows of shape (w, h) with the given step (strided view).
    data = vaw(mspec, (w,h), step=step)
    # Flatten each window into a row vector (in-place shape assignment).
    data.shape = (len(data), w*h)
    # Per-patch normalization: subtract the patch mean, divide by the patch
    # std. NOTE(review): a constant (e.g. silent) patch has std == 0 and
    # yields inf/nan here; such patches are flagged False in `finite` below.
    data = (data - np.mean(data, axis=1).reshape((len(data), 1))) / np.std(data, axis=1).reshape((len(data), 1))
    # Pad by repeating the first/last patch. NOTE(review): the counts
    # (w // (2*step) left, w // (2*step) - 1 + len(mspec) % 2 right)
    # presumably realign patch centers with input frames — confirm.
    lfill = [data[0,:].reshape(1, h*w)] * (w // (2 * step))
    rfill = [data[-1,:].reshape(1, h*w)] * (w // (2* step) - 1 + len(mspec) % 2)
    data = np.vstack(lfill + [data] + rfill )
    finite = np.all(np.isfinite(data), axis=1)
    data.shape = (len(data), w, h)
    return data, finite
@profile
def _binidx2seglist(binidx):
    curlabel = None
    bseg = -1
    ret = []
    for i, e in enumerate(binidx):
        if e != curlabel:
            if curlabel is not None:
                ret.append((curlabel, bseg, i))
            curlabel = e
            bseg = i
    ret.append((curlabel, bseg, i + 1))
    return ret




def segmentwav(wavname):
    """Segment a wav file: run energy-based VAD, keep the SHORTEST speech
    segment found, and classify its gender.

    Parameters
    ----------
    wavname : str
        Path to the wav file.

    Returns
    -------
    list of (label, start_sec, stop_sec) tuples; frame indices are converted
    to seconds with a 20 ms step. Empty list when no speech is detected
    (the original crashed on '' placeholder bounds in that case).
    """
    mspec, loge = _wav2feats(wavname)

    # Pad the mel spectrogram so at least one 68-frame patch can be cut.
    difflen = 0
    if len(loge) < 68:
        difflen = 68 - len(loge)
        mspec = np.concatenate((mspec, np.ones((difflen, 24)) * np.min(mspec)))

    # Keep only the shortest speech segment produced by the energy VAD
    # (frame indices are halved by the [::2] subsampling).
    best = None  # (start, stop) of the shortest speech segment seen so far
    for lab, start, stop in _binidx2seglist(_energy_activity(loge)[::2]):
        if lab == 0:
            continue  # noEnergy
        if best is None or (stop - start) < (best[1] - best[0]):
            best = (start, stop)

    lseg = []
    if best is not None:
        lseg.append(('speech', best[0], best[1]))
    lseg = gender(mspec, lseg, difflen)
    return [(lab, start * .02, stop * .02) for lab, start, stop in lseg]


def Segmenter(voicefile):
    """Segment *voicefile* and return (label, start_sec, stop_sec) tuples.

    Thin wrapper around segmentwav; the unused basename computation was
    removed.
    """
    return segmentwav(voicefile)


# Directory containing this module; the gender-CNN weights live next to it.
The_p = os.path.dirname(os.path.realpath(__file__)) + '/'
# Load the male/female CNN once at import time (compile=False: inference only).
The_nn = keras.models.load_model(The_p + 'keras_male_female_cnn.hdf5', compile=False)
# NOTE(review): leftover debug print from a memory-leak investigation.
print('The_nn',sys.getrefcount(The_nn))
def get_predict(patches, start, stop):
    """Run the module-level gender CNN on patches[start:stop].

    Parameters
    ----------
    patches : 3D array of normalized spectrogram patches.
    start, stop : slice bounds into `patches`.

    Returns
    -------
    Raw per-class prediction array from The_nn.predict.
    """
    # Removed the pointless `New_NN = The_nn` alias that was created and then
    # deleted without ever being used, plus the debug @profile decorator.
    # Predictions must run inside the graph the model was loaded in.
    with graph.as_default():
        rawpred = The_nn.predict(patches[start:stop, :])
    return rawpred
@profile
def gender(mspec, lseg, difflen):
    """Classify each 'speech' segment of lseg as 'female' or 'male'.

    Parameters
    ----------
    mspec : 2D mel spectrogram (nframes, 24).
    lseg : list of (label, start, stop) frame segments; only segments labeled
        'speech' are classified, others are skipped.
    difflen : number of padding frames the caller appended to mspec; the
        corresponding trailing patches are dropped below.

    Returns
    -------
    list of (gender_label, start, stop) frame segments.
    """
    outlabels = ('female', 'male')
    inlabel = 'speech'
    nmel = 24
    viterbi_arg = 80
    # NOTE(review): dead branch (nmel is hard-coded to 24). If ever enabled,
    # `mspec[:, nmel]` selects a single column — `mspec[:, :nmel]` was
    # probably intended. Confirm before reactivating.
    if nmel < 24:
        mspec = mspec[:, nmel].copy()

    # 68-frame patches with a 2-frame hop, matching the CNN input size.
    patches, finite = _get_patches(mspec, 68, 2)
    if difflen > 0:
        # Drop the patches that only exist because of the caller's padding.
        patches = patches[:-int(difflen / 2), :, :]
        finite = finite[:-int(difflen / 2)]
    assert len(finite) == len(patches), (len(patches), len(finite))
    ret = []
    for lab, start, stop in lseg:
        if lab != inlabel:
            continue
        # Cap the classified window at 10 patch frames.
        if (stop - start) > 10:
            stop = start + 10

        rawpred = get_predict(patches,start,stop)
        # Neutralize patches containing non-finite values (0.5 = no evidence).
        rawpred[finite[start:stop] == False, :] = 0.5
        pred = viterbi_decoding(np.log(rawpred), diag_trans_exp(viterbi_arg, len(outlabels)))
        # Deliberate eager release of the prediction array (memory debugging).
        del rawpred
        gc.collect()
        for lab2, start2, stop2 in _binidx2seglist(pred):
            ret.append((outlabels[int(lab2)], start2 + start, stop2 + start))
            # NOTE(review): this return sits INSIDE both loops, so only the
            # first sub-segment of the first speech segment is ever returned.
            # It looks like it should be dedented past the outer loop —
            # confirm against the intended behavior before changing.
            return ret
    return ret
